package com.example.djlocrspringboot.djlocr.djl.utils.recognition;

import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.bytedeco.opencv.opencv_core.Point2f;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import com.example.djlocrspringboot.djlocr.djl.utils.common.RotatedBox;
import com.example.djlocrspringboot.djlocr.djl.utils.opencv.NDArrayUtils;
import com.example.djlocrspringboot.djlocr.djl.utils.opencv.OpenCVUtils;

import ai.djl.Device;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.opencv.OpenCVImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

/**
 * PP-OCRv3 text recognition helper.
 *
 * <p>Builds the DJL {@link Criteria} for the PaddlePaddle recognition model and runs a
 * detect-then-recognize pipeline: every region returned by a detection predictor is cropped out of
 * the source image, rotated if it looks like vertical text, and passed to a recognition predictor.
 */
@Component
public class OcrV3Recognition {

    private static final Logger logger = LoggerFactory.getLogger(OcrV3Recognition.class);

    /** Model archive location handed to {@code Criteria.Builder#optModelUrls}. */
    private static final String REC_MODEL_URL = "/guige/ch_PP-OCRv3_rec_infer.zip";

    /**
     * Height/width ratio above which a cropped region is treated as vertical text and rotated
     * before recognition.
     */
    private static final double VERTICAL_TEXT_ASPECT_RATIO = 1.5;

    public OcrV3Recognition() {
    }

    /**
     * Builds the recognition-model criteria on the GPU.
     *
     * <p>Equivalent to {@code recognizeCriteria(Device.gpu())}.
     *
     * @return criteria for an {@code Image -> String} PaddlePaddle recognition model
     */
    public Criteria<Image, String> recognizeCriteria() {
        return recognizeCriteria(Device.gpu());
    }

    /**
     * Builds the recognition-model criteria for the given device.
     *
     * @param device device to load the model on (for example {@code Device.gpu()} or
     *     {@code Device.cpu()}); passed straight to {@code optDevice}
     * @return criteria for an {@code Image -> String} PaddlePaddle recognition model
     */
    public Criteria<Image, String> recognizeCriteria(Device device) {
        return Criteria.builder()
                .optEngine("PaddlePaddle")
                .optDevice(device)
                .setTypes(Image.class, String.class)
                .optModelUrls(REC_MODEL_URL)
                .optTranslator(new PpWordRecognitionTranslator(new ConcurrentHashMap<String, String>()))
                .optProgress(new ProgressBar())
                .build();
    }

    /**
     * Detects text regions in {@code image} and recognizes the text of each one.
     *
     * <p>NOTE(review): each detected box is closed before this method returns. The inner
     * try-with-resources closes it, and closing {@code boxes} closes it again. The same NDArray
     * instance is still stored in the returned {@link RotatedBox}. If RotatedBox keeps the
     * reference instead of copying the coordinates, callers must not read the box NDArray from
     * the result. Confirm against RotatedBox before relying on it.
     *
     * @param image source image; must be OpenCV-backed (wrapped object is {@code org.opencv.core.Mat})
     * @param detector predictor producing one NDArray of corner points per detected region
     * @param recognizer predictor turning a cropped region into text
     * @return one {@link RotatedBox} per detected region, in detection order
     * @throws RuntimeException wrapping any {@link TranslateException} from either predictor
     */
    public List<RotatedBox> predict(
            Image image, Predictor<Image, NDList> detector, Predictor<Image, String> recognizer)
            throws TranslateException {
        try (NDList boxes = detector.predict(image)) {
            List<RotatedBox> result = new ArrayList<>(boxes.size());
            for (int i = 0; i < boxes.size(); i++) {
                try (NDArray box = boxes.get(i)) {
                    String text = recognizeRegion(image, box, recognizer);
                    result.add(new RotatedBox(box, text));
                }
            }
            return result;
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Detects text regions in {@code image} and returns their recognized text concatenated in
     * detection order, with no separator between regions.
     *
     * @param image source image; must be OpenCV-backed (wrapped object is {@code org.opencv.core.Mat})
     * @param detector predictor producing one NDArray of corner points per detected region
     * @param recognizer predictor turning a cropped region into text
     * @return concatenated text of all regions (empty if nothing was detected)
     * @throws RuntimeException wrapping any {@link TranslateException} from either predictor
     */
    public String predict2Str(
            Image image, Predictor<Image, NDList> detector, Predictor<Image, String> recognizer)
            throws TranslateException {
        StringBuilder text = new StringBuilder();
        try (NDList boxes = detector.predict(image)) {
            for (int i = 0; i < boxes.size(); i++) {
                try (NDArray box = boxes.get(i)) {
                    text.append(recognizeRegion(image, box, recognizer));
                }
            }
            return text.toString();
        } catch (TranslateException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Crops one detected region, rotates it when it is taller than
     * {@link #VERTICAL_TEXT_ASPECT_RATIO} times its width, and runs recognition on it.
     */
    private String recognizeRegion(Image image, NDArray box, Predictor<Image, String> recognizer)
            throws TranslateException {
        Image region = getRotateCropImage(image, box);
        if (region.getHeight() * 1.0 / region.getWidth() > VERTICAL_TEXT_ASPECT_RATIO) {
            region = rotateImg(region);
        }
        return recognizer.predict(region);
    }

    /**
     * Rectifies the quadrilateral described by {@code box} into an axis-aligned crop using a
     * perspective transform.
     *
     * <p>{@code box} is read as four (x, y) corners in the order top-left, top-right,
     * bottom-right, bottom-left. The crop width and height are the longer of each pair of
     * opposite edges, truncated to int.
     *
     * @param image OpenCV-backed source image
     * @param box NDArray holding at least 8 floats (the four corners)
     * @return the rectified crop
     * @throws IllegalArgumentException if the image is not OpenCV-backed or {@code box} has fewer
     *     than 8 values
     */
    private Image getRotateCropImage(Image image, NDArray box) {
        Object wrapped = image.getWrappedImage();
        if (!(wrapped instanceof org.opencv.core.Mat)) {
            throw new IllegalArgumentException(
                    "Expected an OpenCV-backed image (org.opencv.core.Mat) but got "
                            + (wrapped == null ? "null" : wrapped.getClass().getName()));
        }
        float[] points = box.toFloatArray();
        // copyOfRange would silently zero-pad a short array, producing a bogus corner.
        if (points.length < 8) {
            throw new IllegalArgumentException(
                    "Expected at least 8 box coordinates but got " + points.length);
        }
        float[] lt = java.util.Arrays.copyOfRange(points, 0, 2);
        float[] rt = java.util.Arrays.copyOfRange(points, 2, 4);
        float[] rb = java.util.Arrays.copyOfRange(points, 4, 6);
        float[] lb = java.util.Arrays.copyOfRange(points, 6, 8);
        int cropWidth = (int) Math.max(distance(lt, rt), distance(rb, lb));
        int cropHeight = (int) Math.max(distance(lt, lb), distance(rt, rb));

        List<Point> srcPoints = new ArrayList<>(4);
        srcPoints.add(new Point(lt[0], lt[1]));
        srcPoints.add(new Point(rt[0], rt[1]));
        srcPoints.add(new Point(rb[0], rb[1]));
        srcPoints.add(new Point(lb[0], lb[1]));
        List<Point> dstPoints = new ArrayList<>(4);
        dstPoints.add(new Point(0, 0));
        dstPoints.add(new Point(cropWidth, 0));
        dstPoints.add(new Point(cropWidth, cropHeight));
        dstPoints.add(new Point(0, cropHeight));

        try (Point2f srcPoint2f = NDArrayUtils.toOpenCVPoint2f(srcPoints, 4);
             Point2f dstPoint2f = NDArrayUtils.toOpenCVPoint2f(dstPoints, 4)) {
            // Round-trip org.opencv Mat -> BufferedImage -> bytedeco Mat so the bytedeco-based
            // perspectiveTransform helper can be used.
            BufferedImage bufferedImage = OpenCVUtils.matToBufferedImage((org.opencv.core.Mat) wrapped);
            OpenCVFrameConverter.ToMat toMat = new OpenCVFrameConverter.ToMat();
            org.bytedeco.opencv.opencv_core.Mat srcMat =
                    toMat.convertToMat(new Java2DFrameConverter().convert(bufferedImage));
            org.bytedeco.opencv.opencv_core.Mat warped =
                    OpenCVUtils.perspectiveTransform(srcMat, srcPoint2f, dstPoint2f);

            // Convert back to an org.opencv Mat so DJL's OpenCVImageFactory can wrap it.
            OpenCVFrameConverter.ToMat matToFrame = new OpenCVFrameConverter.ToMat();
            OpenCVFrameConverter.ToOrgOpenCvCoreMat frameToOrgMat =
                    new OpenCVFrameConverter.ToOrgOpenCvCoreMat();
            org.opencv.core.Mat orgMat = frameToOrgMat.convert(matToFrame.convert(warped));
            Image warpedImage = OpenCVImageFactory.getInstance().fromImage(orgMat);
            return warpedImage.getSubImage(0, 0, cropWidth, cropHeight);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Euclidean distance between two (x, y) points given as float arrays. */
    private float distance(float[] point1, float[] point2) {
        float dx = point1[0] - point2[0];
        float dy = point1[1] - point2[1];
        return (float) Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Returns a copy of {@code image} rotated by one 90-degree step using
     * {@code NDImageUtils.rotate90(..., 1)}. The temporary NDArrays are released when the local
     * manager closes.
     */
    private Image rotateImg(Image image) {
        try (NDManager manager = NDManager.newBaseManager();
             NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1)) {
            return ImageFactory.getInstance().fromNDArray(rotated);
        }
    }
}
